import tkinter as tk
from tkinter import ttk
import numpy as np
import math
import re
from OpenGL import GL, GLU, GLUT
from OpenGL.GL.shaders import compileShader, compileProgram
from pyopengltk import OpenGLFrame

# --- HDGL Machine (Modified for GPU Compute) ---
class HDGLMachine:
    """Calculator / harmonic-state engine driven by a fixed table of constants.

    The constants are grouped into an "upper field", "analog dimensions" and a
    "lower field".  They feed two features:

    * ``evaluate_expression`` -- a small infix calculator that runs either in
      plain base-10 arithmetic or in "HDGL" mode, where every binary operator
      and function adds a field constant to its ordinary result;
    * ``step`` / ``compute_harmonic_state`` -- a time-dependent scalar built
      from sines/cosines weighted by every field value, optionally computed by
      an external (GPU) callable.

    NOTE(review): ``phi_power_phi`` is 2.6180339887, which equals phi**2 rather
    than phi**phi (~2.178); ``inv_phi_power_phi`` matches 1/phi**2.  Confirm
    which was intended.
    """

    # Tokens: numbers (including ".5", "5." and exponents such as "1e-05"), the
    # special values "inf"/"nan" that str(float) can produce (the GUI feeds
    # results back in as text), function names, operators and parentheses.
    # Any other character (e.g. whitespace) is skipped by findall.
    _TOKEN_PATTERN = re.compile(
        r"(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?|inf|nan|sin|cos|tan|exp|log|[-+*/()]"
    )

    def __init__(self):
        # Upper Field
        self.upper_field = {
            "prism_state": 105.0,
            "recursion_mode": 99.9999999999,
            "positive_infinite": 9.9999999999,
            "P3": 4.2360679775,
            "pi": 3.1415926535,
            "phi_power_phi": 2.6180339887,
            "phi": 1.6180339887
        }
        # Analog Dimensions
        self.analog_dims = {
            "D8": 8.0, "D7": 7.0, "D6": 6.0, "D5": 5.0, "D4": 4.0, "D3": 3.0, "D2": 2.0, "D1": 1.0,
            "dim_switch": 1.0,
            "P4": 6.8541019662, "P5": 11.0901699437, "P6": 17.94427191, "P7": 29.0344465435
        }
        # Lower Field
        self.lower_field = {
            "negative_infinite": 1e-10,
            "inv_P7": 0.0344465435,
            "inv_P6": 0.05572809,
            "inv_P5": 0.0901699437,
            "inv_P4": 0.1458980338,
            "inv_P3": 0.2360679775,
            "inv_phi_power_phi": 0.3819660113,
            "inv_phi": 0.6180339887
        }

        self.void = 0.0
        self.current_state = self.void
        self.dimension = self.analog_dims["dim_switch"]
        self.recursion_active = False
        self.use_hdgl_base = False
        self.seed = 1.0

        # float32 snapshots of the field values, in dict insertion order, for
        # upload to GPU buffers (later edits to the dicts are not reflected).
        self.upper_field_values = np.array(list(self.upper_field.values()), dtype=np.float32)
        self.lower_field_values = np.array(list(self.lower_field.values()), dtype=np.float32)
        self.analog_dims_values = np.array(list(self.analog_dims.values()), dtype=np.float32)

    # --- Toggles ---
    def toggle_recursion(self):
        """Flip the recursion multiplier applied in ``step``."""
        self.recursion_active = not self.recursion_active
        print(f"Recursion: {'ON' if self.recursion_active else 'OFF'}")

    def toggle_dimension(self):
        """Switch ``dimension`` between 1.0 (sin envelope) and 0.0 (cos envelope)."""
        self.dimension = 1.0 if self.dimension != 1.0 else 0.0
        print(f"Dimension: {'2D double-helix' if self.dimension == 1.0 else '1D'}")

    def toggle_base(self):
        """Switch ``evaluate_expression`` between HDGL and base-10 arithmetic."""
        self.use_hdgl_base = not self.use_hdgl_base
        print(f"HDGL Base: {'ON' if self.use_hdgl_base else 'Base-10'}")

    # --- Arithmetic (CPU) ---
    def hdgl_add(self, a, b): return a + b + self.upper_field["phi"]
    def hdgl_sub(self, a, b): return a - b + self.lower_field["inv_phi"]
    def hdgl_mul(self, a, b): return a * b + self.upper_field["P3"]
    def hdgl_div(self, a, b): return float('inf') if b == 0 else a / b + self.lower_field["inv_P3"]

    def base10_add(self, a, b): return a + b
    def base10_sub(self, a, b): return a - b
    def base10_mul(self, a, b): return a * b
    def base10_div(self, a, b): return float('inf') if b == 0 else a / b

    # --- Functions (CPU) ---
    def base10_sin(self, x): return math.sin(x)
    def base10_cos(self, x): return math.cos(x)
    def base10_tan(self, x): return math.tan(x)
    def base10_exp(self, x): return math.exp(x)
    def base10_log(self, x): return math.log(x) if x > 0 else float('nan')

    def hdgl_sin(self, x): return math.sin(x) + self.upper_field["phi"]
    def hdgl_cos(self, x): return math.cos(x) + self.lower_field["inv_phi"]
    def hdgl_tan(self, x): return math.tan(x) + self.upper_field["P3"]
    def hdgl_exp(self, x): return math.exp(x) + self.upper_field["phi_power_phi"]
    def hdgl_log(self, x): return (math.log(x) if x > 0 else float('nan')) + self.lower_field["inv_P3"]

    # --- Primitive getter ---
    def get_primitive(self, name):
        """Return the constant for a calculator primitive name, or 0 if unknown."""
        return {
            "phi": self.upper_field["phi"],
            "phi^phi": self.upper_field["phi_power_phi"],
            "pi": self.upper_field["pi"],
            "P3": self.upper_field["P3"],
            "P4": self.analog_dims["P4"]
        }.get(name, 0)

    # --- Expression evaluation (CPU) ---
    def evaluate_expression(self, expr):
        """Evaluate an infix calculator expression and return a float.

        Supported syntax:

        * numbers such as ``2``, ``2.5``, ``.5``, ``1e-05``, plus ``inf``/``nan``;
        * the primitives ``phi``, ``phi^phi``, ``pi``, ``P3`` and ``P4``, which
          are substituted textually before parsing;
        * ``+ - * /`` with the usual precedence, evaluated left to right;
        * unary ``+``/``-``;
        * parentheses;
        * prefix functions ``sin cos tan exp log``.  A function applies to the
          operand that immediately follows it, so ``sin 2*3`` is ``sin(2)*3``.

        In HDGL mode the binary operators and functions are the ``hdgl_*``
        variants.  Unary minus is always plain negation, so a negative literal
        (e.g. a previous "-2.0" result fed back in) keeps its value.

        Returns 0.0 when ``expr`` contains no recognisable token.  Returns
        ``nan`` for malformed input (dangling operator, unbalanced or empty
        parentheses, juxtaposed operands) or a math domain/overflow error.
        """
        for key in ("phi^phi", "phi", "pi", "P3", "P4"):
            expr = expr.replace(key, f"({self.get_primitive(key)})")
        tokens = self._TOKEN_PATTERN.findall(expr)
        if not tokens:
            return 0.0

        if self.use_hdgl_base:
            add, sub, mul, div = self.hdgl_add, self.hdgl_sub, self.hdgl_mul, self.hdgl_div
            funcs = {
                "sin": self.hdgl_sin,
                "cos": self.hdgl_cos,
                "tan": self.hdgl_tan,
                "exp": self.hdgl_exp,
                "log": self.hdgl_log
            }
        else:
            add, sub, mul, div = self.base10_add, self.base10_sub, self.base10_mul, self.base10_div
            funcs = {
                "sin": self.base10_sin,
                "cos": self.base10_cos,
                "tan": self.base10_tan,
                "exp": self.base10_exp,
                "log": self.base10_log
            }

        pos = 0  # index of the next unconsumed token

        def peek():
            # Next token without consuming it, or None at end of input.
            return tokens[pos] if pos < len(tokens) else None

        def take():
            # Consume and return the next token; running out is a syntax error.
            nonlocal pos
            if pos >= len(tokens):
                raise ValueError("unexpected end of expression")
            tok = tokens[pos]
            pos += 1
            return tok

        def parse_sum():
            # sum := product (('+' | '-') product)*
            value = parse_product()
            while peek() in ("+", "-"):
                op = take()
                rhs = parse_product()
                value = add(value, rhs) if op == "+" else sub(value, rhs)
            return value

        def parse_product():
            # product := unary (('*' | '/') unary)*
            value = parse_unary()
            while peek() in ("*", "/"):
                op = take()
                rhs = parse_unary()
                value = mul(value, rhs) if op == "*" else div(value, rhs)
            return value

        def parse_unary():
            # unary := ('-' | '+') unary | atom
            tok = peek()
            if tok == "-":
                take()
                return -parse_unary()
            if tok == "+":
                take()
                return parse_unary()
            return parse_atom()

        def parse_atom():
            # atom := NUMBER | '(' sum ')' | FUNC unary
            tok = take()
            if tok == "(":
                value = parse_sum()
                if take() != ")":
                    raise ValueError("expected ')'")
                return value
            if tok in funcs:
                return funcs[tok](parse_unary())
            if tok in ("+", "-", "*", "/", ")"):
                raise ValueError(f"unexpected token {tok!r}")
            return float(tok)

        try:
            result = parse_sum()
            if pos != len(tokens):
                # e.g. "2(3)": leftover tokens mean the input was not one expression
                raise ValueError(f"unexpected token {tokens[pos]!r}")
            return float(result)
        except (ValueError, ArithmeticError, RecursionError):
            return float('nan')

    # --- Harmonic state (for CPU fallback, not used if GPU active) ---
    def compute_harmonic_state(self, t):
        """Return the CPU-computed harmonic state at time ``t`` (scaled by ``seed``)."""
        state = self.void
        for v in self.upper_field.values():
            state += v * math.sin(t * self.upper_field["phi"])
        for v in self.lower_field.values():
            state += v * math.cos(t * self.lower_field["inv_phi"])
        for v in self.analog_dims.values():
            state += v * math.sin(t * self.upper_field["phi_power_phi"])
        if self.dimension == 1.0:
            state *= math.sin(t * self.upper_field["pi"])
        else:
            state *= math.cos(t * self.upper_field["pi"])
        state *= self.seed
        return state

    def step(self, t, use_gpu=False, compute_shader=None):
        """Compute, store and return the machine state at time ``t``.

        When ``use_gpu`` is true and ``compute_shader`` is given, the state is
        whatever that callable returns for the machine's parameters.  Otherwise
        ``compute_harmonic_state`` is used.  With recursion active, the result
        is multiplied by ``P7 / inv_P7``.
        """
        if use_gpu and compute_shader:
            state = compute_shader(t, self.seed, self.dimension,
                                   self.upper_field_values, self.lower_field_values,
                                   self.analog_dims_values,
                                   self.upper_field["phi"], self.lower_field["inv_phi"],
                                   self.upper_field["phi_power_phi"], self.upper_field["pi"])
        else:
            state = self.compute_harmonic_state(t)
        if self.recursion_active:
            state *= self.analog_dims["P7"] / self.lower_field["inv_P7"]
        self.current_state = state
        return state

    def kick(self, value):
        """Use ``value`` as the new state multiplier, falling back to 1.0.

        Zero, ``nan`` and infinities are rejected.  A non-finite seed would
        make every later state non-finite, and the calculator passes ``nan``
        for any malformed expression.
        """
        self.seed = value if value != 0 and math.isfinite(value) else 1.0

# --- GPU-Accelerated OpenGL Visualizer with Compute Shader ---
class HDGLSuperVisualizer(OpenGLFrame):
    """Tk calculator UI plus an OpenGL strip chart of ``machine.step``.

    The widget creates and owns its own ``tk.Tk`` root.  Each call to
    ``redraw`` advances simulation time by ``dt``, samples the machine, and
    plots the most recent ``window`` samples as a cyan line strip.  Sampling
    uses a GLSL compute shader when one could be built in the current GL
    context; otherwise it falls back to the machine's CPU implementation.
    """

    # Uniforms set on the compute program for every dispatch.
    _UNIFORM_NAMES = ("t", "seed", "dimension", "phi", "inv_phi", "phi_power_phi", "pi_val")

    def __init__(self, machine, dt=0.05, window=200):
        self.machine = machine
        self.dt = dt
        self.t = 0.0
        self.window = window
        self.times = []
        self.values = []
        self.expression = ""

        # GPU resources are created in initgl() once a GL context exists.  Until
        # then, or if creation fails, update_data() samples on the CPU.
        self.compute_shader_program = None
        self.ssbo_upper = None
        self.ssbo_lower = None
        self.ssbo_analog = None
        self.ssbo_output = None
        self._uniform_locations = {}
        self._compute_unavailable = False

        # Root Tkinter window
        self.root = tk.Tk()
        self.root.title("HDGL Super Visualizer + Calculator (GPU OpenGL + Compute)")
        self.root.geometry("1000x700")

        # Display
        self.display = tk.Entry(self.root, font=("Consolas", 18), bd=5, relief="sunken", justify="right")
        self.display.pack(fill="x", padx=10, pady=5)

        # Buttons
        btn_frame = ttk.Frame(self.root)
        btn_frame.pack(padx=10, pady=5)
        buttons = [
            "7", "8", "9", "/", "C",
            "4", "5", "6", "*", "(",
            "1", "2", "3", "-", ")",
            "0", ".", "=", "+", " )"
        ]
        for i, t in enumerate(buttons):
            b = ttk.Button(btn_frame, text=t, command=lambda x=t: self.on_button(x))
            b.grid(row=i // 5, column=i % 5, sticky="nsew", padx=2, pady=2)
        for i in range(5):
            btn_frame.columnconfigure(i, weight=1)
            btn_frame.rowconfigure(i, weight=1)

        # Primitives + Controls
        func_frame = ttk.Frame(self.root)
        func_frame.pack(padx=10, pady=5)
        funcs = ["phi", "phi^phi", "pi", "P3", "P4", "sin", "cos", "tan", "exp", "log", "Recursion", "Toggle Base"]
        for i, t in enumerate(funcs):
            if t == "Recursion":
                b = ttk.Button(func_frame, text=t, command=self.machine.toggle_recursion)
            elif t == "Toggle Base":
                b = ttk.Button(func_frame, text=t, command=self.machine.toggle_base)
            else:
                b = ttk.Button(func_frame, text=t, command=lambda x=t: self.on_button(x))
            b.grid(row=0, column=i, sticky="nsew", padx=2, pady=2)
        for i in range(len(funcs)):
            func_frame.columnconfigure(i, weight=1)

        # OpenGL Frame
        super().__init__(self.root, width=800, height=400)
        self.pack(fill="both", expand=True, padx=10, pady=5)
        # NOTE(review): pyopengltk presumably re-schedules redraws while
        # `animate` is truthy; confirm against the installed version.
        self.animate = True

        # Initial data point.  No GL context or compute program exists yet, so
        # this sample comes from the CPU path.  The previous code forced the GPU
        # path here and failed before initgl() had run.
        self.update_data()

        # No periodic tk_loop is scheduled: mainloop() already services events.

    def initgl(self):
        """Set viewport, projection and blend state, and ensure GPU resources exist.

        This may be invoked more than once (e.g. on resize), so the compute
        program and SSBOs are only created when they are not valid in the
        current context.  Re-creating them on every call leaked the previous
        program and buffers.
        """
        GL.glViewport(0, 0, self.width, self.height)
        self._apply_projection()
        GL.glClearColor(0.0, 0.0, 0.0, 1.0)
        GL.glLineWidth(2.0)
        GL.glEnable(GL.GL_BLEND)
        GL.glBlendFunc(GL.GL_SRC_ALPHA, GL.GL_ONE_MINUS_SRC_ALPHA)

        if self._compute_unavailable:
            return
        program = self.compute_shader_program
        if program is not None and GL.glIsProgram(program):
            return  # already built in this context
        self._create_compute_resources()

    def _create_compute_resources(self):
        """Compile the compute program and upload the machine's fields to SSBOs.

        On failure (e.g. a context without compute-shader support), records
        that the GPU path is unavailable so update_data() keeps using the CPU.
        """
        from OpenGL import error as gl_error

        try:
            program = self.init_compute_shader()
        except (RuntimeError, gl_error.Error) as exc:
            self._compute_unavailable = True
            self.compute_shader_program = None
            print(f"Compute shader unavailable, using CPU fallback: {exc}")
            return

        self.compute_shader_program = program
        # Look uniform locations up once instead of on every dispatch.
        self._uniform_locations = {
            name: GL.glGetUniformLocation(program, name) for name in self._UNIFORM_NAMES
        }

        self.ssbo_upper = GL.glGenBuffers(1)
        self.ssbo_lower = GL.glGenBuffers(1)
        self.ssbo_analog = GL.glGenBuffers(1)
        self.ssbo_output = GL.glGenBuffers(1)

        # Upload the (static) field data; the output buffer holds one float.
        uploads = (
            (self.ssbo_upper, self.machine.upper_field_values),
            (self.ssbo_lower, self.machine.lower_field_values),
            (self.ssbo_analog, self.machine.analog_dims_values),
        )
        for buffer_id, data in uploads:
            GL.glBindBuffer(GL.GL_SHADER_STORAGE_BUFFER, buffer_id)
            GL.glBufferData(GL.GL_SHADER_STORAGE_BUFFER, data.nbytes, data, GL.GL_STATIC_DRAW)
        GL.glBindBuffer(GL.GL_SHADER_STORAGE_BUFFER, self.ssbo_output)
        GL.glBufferData(GL.GL_SHADER_STORAGE_BUFFER, 4, None, GL.GL_DYNAMIC_READ)
        GL.glBindBuffer(GL.GL_SHADER_STORAGE_BUFFER, 0)

    def init_compute_shader(self):
        """Compile and link the harmonic-state compute shader; return the program.

        The shader writes its result to ``state_out`` in the binding-3 SSBO.
        The previous source declared that member as ``state`` but assigned to
        an undeclared ``state_out``, so it never compiled.
        """
        compute_shader_source = """
        #version 430
        layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
        layout(std430, binding = 0) buffer UpperField { float upper_field[]; };
        layout(std430, binding = 1) buffer LowerField { float lower_field[]; };
        layout(std430, binding = 2) buffer AnalogDims { float analog_dims[]; };
        layout(std430, binding = 3) buffer Output { float state_out; };
        uniform float t;
        uniform float seed;
        uniform float dimension;
        uniform float phi;
        uniform float inv_phi;
        uniform float phi_power_phi;
        uniform float pi_val;
        void main() {
            float state = 0.0;
            for (int i = 0; i < upper_field.length(); i++)
                state += upper_field[i] * sin(t * phi);
            for (int i = 0; i < lower_field.length(); i++)
                state += lower_field[i] * cos(t * inv_phi);
            for (int i = 0; i < analog_dims.length(); i++)
                state += analog_dims[i] * sin(t * phi_power_phi);
            state *= (dimension == 1.0) ? sin(t * pi_val) : cos(t * pi_val);
            state *= seed;
            state_out = state;
        }
        """
        shader = compileShader(compute_shader_source, GL.GL_COMPUTE_SHADER)
        return compileProgram(shader)

    def compute_harmonic_state_gpu(self, t, seed, dimension, upper_field, lower_field, analog_dims, phi, inv_phi, phi_power_phi, pi_val):
        """Dispatch the compute shader once and return the resulting state.

        ``upper_field``, ``lower_field`` and ``analog_dims`` are accepted to
        match the callable signature used by ``HDGLMachine.step``.  The shader
        reads the copies uploaded to SSBOs when the GPU resources were created.
        """
        GL.glUseProgram(self.compute_shader_program)
        try:
            GL.glBindBufferBase(GL.GL_SHADER_STORAGE_BUFFER, 0, self.ssbo_upper)
            GL.glBindBufferBase(GL.GL_SHADER_STORAGE_BUFFER, 1, self.ssbo_lower)
            GL.glBindBufferBase(GL.GL_SHADER_STORAGE_BUFFER, 2, self.ssbo_analog)
            GL.glBindBufferBase(GL.GL_SHADER_STORAGE_BUFFER, 3, self.ssbo_output)

            uniform_values = {
                "t": t,
                "seed": seed,
                "dimension": dimension,
                "phi": phi,
                "inv_phi": inv_phi,
                "phi_power_phi": phi_power_phi,
                "pi_val": pi_val,
            }
            for name, value in uniform_values.items():
                GL.glUniform1f(self._uniform_locations[name], value)

            GL.glDispatchCompute(1, 1, 1)
            # BUFFER_UPDATE is the barrier that makes shader writes visible to
            # glGetBufferSubData; SHADER_STORAGE alone only covers later shader reads.
            GL.glMemoryBarrier(GL.GL_SHADER_STORAGE_BARRIER_BIT | GL.GL_BUFFER_UPDATE_BARRIER_BIT)

            GL.glBindBuffer(GL.GL_SHADER_STORAGE_BUFFER, self.ssbo_output)
            raw = GL.glGetBufferSubData(GL.GL_SHADER_STORAGE_BUFFER, 0, 4)
            GL.glBindBuffer(GL.GL_SHADER_STORAGE_BUFFER, 0)
        finally:
            # Unbind the compute-only program so the fixed-function line strip
            # in redraw() is not issued while it is active.
            GL.glUseProgram(0)
        return float(np.frombuffer(raw, dtype=np.float32)[0])

    def _apply_projection(self):
        """Map the sampled time span to x and [-500, 500] to y."""
        left = self.times[0] if self.times else 0.0
        if len(self.times) > 1:
            right = self.times[-1]
        else:
            # With zero or one sample, left == right would give glOrtho a
            # zero-width box, which GL rejects; use the nominal window width.
            right = left + self.window * self.dt
        GL.glMatrixMode(GL.GL_PROJECTION)
        GL.glLoadIdentity()
        GL.glOrtho(left, right, -500, 500, -1, 1)
        GL.glMatrixMode(GL.GL_MODELVIEW)
        GL.glLoadIdentity()

    def redraw(self):
        """Advance one sample, then draw the current window as a cyan line strip."""
        self.update_data()
        self._apply_projection()
        GL.glClear(GL.GL_COLOR_BUFFER_BIT)
        GL.glColor3f(0.0, 1.0, 1.0)  # Cyan line

        if len(self.times) > 1:
            GL.glBegin(GL.GL_LINE_STRIP)
            for t_val, v_val in zip(self.times, self.values):
                GL.glVertex2f(t_val, v_val)
            GL.glEnd()

    def update_data(self):
        """Advance ``t`` by ``dt``, sample the machine, and keep the last ``window`` samples.

        Uses the GPU only when a compute program has been built; otherwise
        calls the machine's CPU path.  Performs no GL state changes, so it is
        safe to call before a GL context exists.
        """
        self.t += self.dt
        if self.compute_shader_program is not None:
            val = self.machine.step(self.t, use_gpu=True, compute_shader=self.compute_harmonic_state_gpu)
        else:
            val = self.machine.step(self.t)
        self.times.append(self.t)
        self.values.append(val)

        excess = len(self.times) - self.window
        if excess > 0:
            del self.times[:excess]
            del self.values[:excess]

    def tk_loop(self):
        """Periodically flush pending Tk idle tasks (kept for compatibility).

        Not scheduled by default, because ``mainloop()`` already services
        events.  Uses ``update_idletasks()`` instead of ``update()``, since
        ``update()`` called from inside an event callback re-enters the event
        loop.
        """
        self.root.update_idletasks()
        self.root.after(50, self.tk_loop)

    def on_button(self, char):
        """Handle a calculator button: evaluate on "=", clear on "C", else append."""
        if char == "=":
            result = self.machine.evaluate_expression(self.expression)
            self.display.delete(0, tk.END)
            self.display.insert(0, str(result))
            self.expression = str(result)
            self.machine.kick(result)
        elif char == "C":
            self.expression = ""
            self.display.delete(0, tk.END)
        else:
            self.expression += str(char)
            self.display.delete(0, tk.END)
            self.display.insert(0, self.expression)

    def run(self):
        """Print the control hint and enter the Tk main loop."""
        print("HDGL Controls: Recursion = toggle, Base = toggle | GPU OpenGL + Compute Shader Active")
        self.mainloop()

# --- Run ---
if __name__ == "__main__":
    # Build the machine, attach the visualizer to it with its default timing,
    # then hand control to the Tk event loop.
    machine = HDGLMachine()
    app = HDGLSuperVisualizer(machine, dt=0.05, window=200)
    app.run()